# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.compiler import ir
from iree.compiler.dialects import iree_codegen
from iree.compiler.dialects import affine
from iree.compiler.ir import AffineMap, AffineDimExpr


def run(fn):
    """Run ``fn`` once under a fresh MLIR context and hand it back unchanged.

    This is used as a decorator in this file, so every decorated function is
    executed at definition time. The call is made with a new ``ir.Context``
    and an unknown ``ir.Location`` active, and with the insertion point set to
    the body of a newly created empty module. A ``TEST: <name>`` banner is
    printed right before ``fn`` is invoked.
    """
    with ir.Context():
        with ir.Location.unknown():
            scratch_module = ir.Module.create()
            with ir.InsertionPoint(scratch_module.body):
                print("\nTEST:", fn.__name__)
                fn()
    return fn


@run
def root_op():
    """Check which ops `get_tuner_root_ops` reports for three matmul modules.

    The same fill + matmul function is parsed three times: with no op carrying
    the `root_op` attribute, with only the matmul carrying it, and with both
    the fill and the matmul carrying it. For each variant the number of
    returned ops and their names (in order) are checked.
    """
    untagged_mlir = """
        module {
            func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %cst = arith.constant 0.000000e+00 : f32
                %0 = tensor.empty() : tensor<4x4xf32>
                %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                %2 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %2 : tensor<4x4xf32>
            }
        }
    """

    matmul_tagged_mlir = """
        module {
            func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %cst = arith.constant 0.000000e+00 : f32
                %0 = tensor.empty() : tensor<4x4xf32>
                %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                %2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %2 : tensor<4x4xf32>
            }
        }
    """

    fill_and_matmul_tagged_mlir = """
        module {
            func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %cst = arith.constant 0.000000e+00 : f32
                %0 = tensor.empty() : tensor<4x4xf32>
                %1 = linalg.fill { root_op } ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                %2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %2 : tensor<4x4xf32>
            }
        }
    """

    # Each entry pairs a module with the op names expected back, in order.
    cases = [
        (untagged_mlir, []),
        (matmul_tagged_mlir, ["linalg.matmul"]),
        (fill_and_matmul_tagged_mlir, ["linalg.fill", "linalg.matmul"]),
    ]
    for mlir_text, expected_names in cases:
        parsed = ir.Module.parse(mlir_text)
        assert parsed is not None, "Failed to parse input MLIR module"
        tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
        assert len(tagged_ops) == len(expected_names)
        for tagged_op, expected_name in zip(tagged_ops, expected_names):
            assert tagged_op.name == expected_name


@run
def attention_op_detail():
    """Check the dimension groups reported by `get_attention_op_detail`.

    Two sets of Q/K/V/O indexing maps are built by hand: a 5-d set and a 4-d
    set whose maps do not follow the expected attention pattern. For both, the
    domain rank and the batch/m/k1/k2/n dimension lists are compared against
    fixed expectations.
    """

    def build_maps(num_dims, *result_positions):
        # One map per entry of `result_positions`; every map has `num_dims`
        # dims, no symbols, and the listed dims as its results.
        dims = [affine.AffineDimExpr.get(i) for i in range(num_dims)]
        return [
            affine.AffineMap.get(num_dims, 0, [dims[pos] for pos in positions])
            for positions in result_positions
        ]

    def check_detail(detail, rank, batch, m, k1, k2, n):
        assert detail.domain_rank == rank
        assert detail.batch_dims == batch
        assert detail.m_dims == m
        assert detail.k1_dims == k1
        assert detail.k2_dims == k2
        assert detail.n_dims == n

    q_map, k_map, v_map, o_map = build_maps(
        5, (0, 1, 2), (0, 3, 2), (0, 3, 4), (0, 1, 4)
    )
    check_detail(
        iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map),
        rank=5,
        batch=[0],
        m=[1],
        k1=[2],
        k2=[3],
        n=[4],
    )

    # Input affine maps that do not follow the expected pattern for an attention operation.
    q_map, k_map, v_map, o_map = build_maps(4, (0, 1), (0, 2), (0, 3), (0, 1))
    check_detail(
        iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map),
        rank=4,
        batch=[0],
        m=[1],
        k1=[],
        k2=[2],
        n=[3],
    )


@run
def test_isa_attention_op():
    """Check that a tagged `iree_linalg_ext.attention` op is seen as attention.

    The module holds one attention op carrying `root_op`. It must be the only
    op returned by `get_tuner_root_ops`, and `isa_attention_op` must accept it.
    """
    module_str = """
        module {
               func.func @attention_20x4096x64x4096x64(
                    %q : tensor<20x4096x64xf16>,
                    %k : tensor<20x4096x64xf16>,
                    %v : tensor<20x4096x64xf16>,
                    %scale : f16,
                    %output : tensor<20x4096x64xf16>
                ) -> tensor<20x4096x64xf16> {
                    %result = iree_linalg_ext.attention { root_op,
                        indexing_maps = [
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                        affine_map<(d0, d1, d2, d3, d4) -> ()>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
                        ]
                    } ins(%q, %k, %v, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
                        outs(%output : tensor<20x4096x64xf16>) {
                    ^bb0(%score: f32):
                        iree_linalg_ext.yield %score : f32
                    } -> tensor<20x4096x64xf16>
                    return %result : tensor<20x4096x64xf16>
                }
        }
    """
    parsed = ir.Module.parse(module_str)
    assert parsed is not None, "Failed to parse input MLIR module"

    tuner_roots = iree_codegen.get_tuner_root_ops(parsed)
    assert len(tuner_roots) == 1

    attention = tuner_roots[0]
    assert attention.name == "iree_linalg_ext.attention"
    assert iree_codegen.isa_attention_op(attention)


@run
def test_igemm_conv_details():
    """Check `get_igemm_generic_conv_details` on several convolution forms.

    Covered cases: `conv_2d_nhwc_hwcf`, `conv_2d_nhwc_fhwc`,
    `conv_2d_nchw_fchw`, a `linalg.generic` expressing a weight-backward
    convolution, and a plain matmul (for which the details must be `None`).
    For each convolution the loop bounds, contraction maps, iterator types,
    im2col output permutation, filter reassociation indices, output-channel
    ordering flag and conv-to-IGEMM dim map are compared to fixed values.
    """

    def only_root_op(module):
        # Fail with a readable assertion (rather than an IndexError on
        # `[0]`) when the module does not expose exactly one tagged op.
        root_ops = iree_codegen.get_tuner_root_ops(module)
        assert (
            len(root_ops) == 1
        ), f"Expected exactly one root op, got {len(root_ops)}"
        return root_ops[0]

    # Dim expressions shared by all expected maps below; the 4-d case only
    # uses d0..d3.
    d0, d1, d2, d3, d4 = [AffineDimExpr.get(i) for i in range(5)]
    four_parallel_one_reduction = ['"parallel"'] * 4 + ['"reduction"']

    # Test 1: conv_2d_nhwc_hwcf.
    module_str = """
        module {
            func.func @conv_2d_nhwc_hwcf(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
                %0 = linalg.conv_2d_nhwc_hwcf { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
                    ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
                    outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
                return %0 : tensor<1x14x14x16xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(only_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NHWC_HWCF conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]

    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # For channel-last (NHWC): input (N,H,W,K), filter (K,OC), output (N,H,W,OC).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d4, d3]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == four_parallel_one_reduction
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0, 1, 2], [3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 2: conv_2d_nhwc_fhwc.
    module_str = """
        module {
            func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
                %0 = linalg.conv_2d_nhwc_fhwc { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
                    ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
                    outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
                return %0 : tensor<1x14x14x16xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(only_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NHWC_FHWC conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # Verify expected affine maps (NHWC_FHWC layout).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d3, d4]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == four_parallel_one_reduction
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 3: conv_2d_nchw_fchw.
    module_str = """
        module {
            func.func @conv_2d_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
                %0 = linalg.conv_2d_nchw_fchw { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
                    ins(%arg0, %arg1 : tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
                    outs(%arg2 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
                return %0 : tensor<1x16x14x14xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(only_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NCHW conv"
    assert details.igemm_loop_bounds == [1, 16, 14, 14, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # Verify expected affine maps for NCHW with loop dims [N, OC, H, W, K].
    # Note: operands are swapped - filter first, then input.
    assert maps[0] == AffineMap.get(5, 0, [d1, d4]), f"Filter map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(
        5, 0, [d0, d2, d3, d4]
    ), f"Input map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == four_parallel_one_reduction
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 4: linalg.generic with convolution pattern (weight backward).
    module_str = """
        module {
            func.func @conv_generic_weight_backward(%arg0: tensor<16x98x64x96xf32>, %arg1: tensor<16x96x64x96xf32>, %arg2: tensor<96x3x96xf32>) -> tensor<96x3x96xf32> {
                %0 = linalg.generic {
                    indexing_maps = [
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d1 + d4, d5, d2)>,
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5, d0)>,
                        affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
                    ],
                    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
                } ins(%arg0, %arg1 : tensor<16x98x64x96xf32>, tensor<16x96x64x96xf32>) outs(%arg2 : tensor<96x3x96xf32>) attrs = {root_op} {
                ^bb0(%in: f32, %in_1: f32, %out: f32):
                    %mul = arith.mulf %in, %in_1 : f32
                    %add = arith.addf %out, %mul : f32
                    linalg.yield %add : f32
                } -> tensor<96x3x96xf32>
                return %0 : tensor<96x3x96xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(only_root_op(input_module))
    assert (
        details is not None
    ), "IGEMM details should be valid for generic 1D conv weight backward"
    assert details.igemm_loop_bounds == [96, 3, 96, 98304]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    assert maps[0] == AffineMap.get(4, 0, [d3, d0]), f"Map 0 mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(4, 0, [d1, d3, d2]), f"Map 1 mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(4, 0, [d0, d1, d2]), f"Map 2 mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == ['"parallel"', '"parallel"', '"parallel"', '"reduction"']
    assert details.im2col_output_perm == [1, 2, 0]
    assert details.filter_reassoc_indices == [[0, 1, 2], [3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 3}

    # Test with a non-conv operation.
    module_str = """
        module {
            func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %0 : tensor<4x4xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    matmul_op = only_root_op(input_module)

    details = iree_codegen.get_igemm_generic_conv_details(matmul_op)
    assert details is None, "IGEMM details should be None for non-conv operation"


@run
def test_isa_scaled_contraction_op():
    """Check `isa_scaled_contraction_op` on a matmul, a fill and a scaled matmul.

    The plain matmul and the fill must be rejected. The `linalg.generic`
    written as a scaled matmul must be accepted, and
    `infer_scaled_contraction_dimensions` must report the expected
    m/n/k/kB/batch groups for it.
    """
    # Test 1: Regular matmul is not a scaled contraction.
    module_str = """
        module {
            func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %0 : tensor<4x4xf32>
            }
        }
    """
    matmul_module = ir.Module.parse(module_str)
    assert matmul_module is not None, "Failed to parse input MLIR module"
    matmul_roots = iree_codegen.get_tuner_root_ops(matmul_module)
    assert len(matmul_roots) == 1

    matmul_is_scaled = iree_codegen.isa_scaled_contraction_op(matmul_roots[0])
    assert not matmul_is_scaled, "Regular matmul should not be a scaled contraction"

    # Test 2: Fill op is not a scaled contraction.
    module_str = """
        module {
            func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %cst = arith.constant 0.000000e+00 : f32
                %0 = linalg.fill { root_op } ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %0 : tensor<4x4xf32>
            }
        }
    """
    fill_module = ir.Module.parse(module_str)
    fill_roots = iree_codegen.get_tuner_root_ops(fill_module)
    assert len(fill_roots) == 1

    fill_is_scaled = iree_codegen.isa_scaled_contraction_op(fill_roots[0])
    assert not fill_is_scaled, "Fill op should not be a scaled contraction"

    # Test 3: Scaled matmul as linalg.generic should be detected.
    # Pattern: linalg.generic with 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output),
    # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb).
    # Uses f4E2M1FN for operands and f8E8M0FNU for scales (matching real scaled matmul pattern).
    module_str = """
        module {
            func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
                                     %lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
                                     %out: tensor<16x16xf32>) -> tensor<16x16xf32> {
                %result = linalg.generic {
                    indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1)>
                    ],
                    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
                    root_op
                } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
                  outs(%out : tensor<16x16xf32>) {
                ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
                    %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
                    %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
                    %prod = arith.mulf %a_scaled, %b_scaled : f32
                    %sum = arith.addf %acc, %prod : f32
                    linalg.yield %sum : f32
                } -> tensor<16x16xf32>
                return %result : tensor<16x16xf32>
            }
        }
    """
    scaled_module = ir.Module.parse(module_str)
    scaled_roots = iree_codegen.get_tuner_root_ops(scaled_module)
    assert len(scaled_roots) == 1, "Should have one root op"

    scaled_generic_op = scaled_roots[0]
    assert iree_codegen.isa_scaled_contraction_op(
        scaled_generic_op
    ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction"

    dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic_op)
    assert dims is not None, "Should be able to infer dimensions for scaled contraction"

    # Expected loop indices for every dimension group, checked in this order.
    expected_groups = (
        ("m", [0]),
        ("n", [1]),
        ("k", [2]),
        ("kB", [3]),
        ("batch", []),
    )
    for group_name, expected in expected_groups:
        actual = getattr(dims, group_name)
        assert actual == expected, f"Got {actual}"


@run
def test_infer_scaled_contraction_dimensions():
    """Check `infer_scaled_contraction_dimensions` on three kinds of op.

    * A scaled matmul written as `linalg.generic`: the m/n/k/kB/batch groups
      must match fixed loop indices.
    * A plain `linalg.matmul`: the result must carry no scaled contraction
      dimensions (either `None` or every group empty).
    * A batched scaled matmul: the batch group must be reported alongside the
      shifted m/n/k/kB groups.
    """
    # Test 1: Verify dimension inference on a scaled matmul operation.
    module_str = """
        module {
            func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
                                     %lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
                                     %out: tensor<16x16xf32>) -> tensor<16x16xf32> {
                %result = linalg.generic {
                    indexing_maps = [
                        affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1)>
                    ],
                    iterator_types = ["parallel", "parallel", "reduction", "reduction"],
                    root_op
                } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
                  outs(%out : tensor<16x16xf32>) {
                ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
                    %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
                    %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
                    %prod = arith.mulf %a_scaled, %b_scaled : f32
                    %sum = arith.addf %acc, %prod : f32
                    linalg.yield %sum : f32
                } -> tensor<16x16xf32>
                return %result : tensor<16x16xf32>
            }
        }
    """
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Should have exactly one root op"
    scaled_op = root_op_list[0]

    assert iree_codegen.isa_scaled_contraction_op(
        scaled_op
    ), "Operation should be recognized as scaled contraction"

    dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op)
    assert dims is not None, "Should successfully infer dimensions"
    assert dims.m == [0], f"Got {dims.m}"
    assert dims.n == [1], f"Got {dims.n}"
    assert dims.k == [2], f"Got {dims.k}"
    assert dims.kB == [3], f"Got {dims.kB}"
    assert dims.batch == [], f"Got {dims.batch}"

    # Test 2: A non-scaled contraction must not yield scaled contraction
    # dimensions.
    module_str_regular = """
        module {
            func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
                %0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
                return %0 : tensor<4x4xf32>
            }
        }
    """
    input_module_regular = ir.Module.parse(module_str_regular)
    regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular)
    assert len(regular_ops) == 1
    regular_matmul = regular_ops[0]

    # Two outcomes are accepted for a regular matmul: `None`, or a result
    # whose dimension groups are all empty. Anything else means the op was
    # wrongly treated as a scaled contraction.
    dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul)
    if dims_regular is not None:
        all_empty = all(
            len(getattr(dims_regular, group)) == 0
            for group in ("m", "n", "k", "kB", "batch")
        )
        assert (
            all_empty
        ), "Regular matmul should not have valid scaled contraction dimensions"

    # Test 3: Batched scaled matmul.
    module_str_batched = """
        module {
            func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>,
                                             %lhs_scales: tensor<8x16x4xf8E8M0FNU>, %rhs_scales: tensor<8x16x4xf8E8M0FNU>,
                                             %out: tensor<8x16x16xf32>) -> tensor<8x16x16xf32> {
                %result = linalg.generic {
                    indexing_maps = [
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
                    ],
                    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"],
                    root_op
                } ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4xf8E8M0FNU>, tensor<8x16x4xf8E8M0FNU>)
                  outs(%out : tensor<8x16x16xf32>) {
                ^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
                    %a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
                    %b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
                    %prod = arith.mulf %a_scaled, %b_scaled : f32
                    %sum = arith.addf %acc, %prod : f32
                    linalg.yield %sum : f32
                } -> tensor<8x16x16xf32>
                return %result : tensor<8x16x16xf32>
            }
        }
    """
    input_module_batched = ir.Module.parse(module_str_batched)
    batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched)
    assert len(batched_ops) == 1, "Batched op should be found"
    batched_op = batched_ops[0]
    assert iree_codegen.isa_scaled_contraction_op(
        batched_op
    ), "Batched scaled matmul should be recognized"

    dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op)
    assert (
        dims_batched is not None
    ), "Batch dimension must be present in batched scaled matmul"
    assert dims_batched.batch == [0], f"Got {dims_batched.batch}"
    assert dims_batched.m == [1], f"Got {dims_batched.m}"
    assert dims_batched.n == [2], f"Got {dims_batched.n}"
    assert dims_batched.k == [3], f"Got {dims_batched.k}"
    assert dims_batched.kB == [4], f"Got {dims_batched.kB}"
